import math
import argparse
import torch
import numpy as np
import shutil
import os
import pickle
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from utils.func import mkdir_p, AverageMeter

# ---------------------------------------------------------------------------
# Command-line interface.
#
# Parsing happens in two passes: the first pass only needs --dataset, which
# selects the model / data-loader modules to import and the dataset-specific
# default for --positive_label_list; the second pass parses everything.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Pytorch Variational Positive Unlabeled Learning')
parser.add_argument('--dataset', default='cifar10',
                    choices=['cifar10', 'fashionMNIST', 'stl10', 'avila', 'pageblocks', 'grid'])
parser.add_argument('--gpu', type=int, default=9)
parser.add_argument('--val-iterations', type=int, default=30)
parser.add_argument('--batch-size', type=int, default=500)
parser.add_argument('--num_labeled', type=int, default=3000)
parser.add_argument('--learning-rate', type=float, default=3e-5)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--mix-alpha', type=float, default=0.3)
parser.add_argument('--lam', type=float, default=0.03)

# First pass. parse_known_args() is required here: --positive_label_list is
# not registered yet, and a strict parse_args() would abort with
# "unrecognized arguments" whenever the user supplies it.
args, _ = parser.parse_known_args()

# --positive_label_list takes one or more space-separated labels
# (e.g. ``--positive_label_list 0 1 8 9``). ``nargs='+'`` yields a real list;
# ``type=list`` would instead split the raw string into single characters.
if args.dataset == 'cifar10':
    from model.model_cifar import NetworkPhi
    from dataset.dataset_cifar import get_cifar10_loaders as get_loaders

    parser.add_argument('--positive_label_list', type=int, nargs='+', default=[0, 1, 8, 9])
elif args.dataset == 'fashionMNIST':
    from model.model_fashionMNIST import NetworkPhi
    from dataset.dataset_fashionMNIST import get_fashionMNIST_loaders as get_loaders

    parser.add_argument('--positive_label_list', type=int, nargs='+', default=[1, 4, 7])
elif args.dataset == 'stl10':
    from model.model_stl import NetworkPhi
    from dataset.dataset_stl import get_stl10_loaders as get_loaders

    parser.add_argument('--positive_label_list', type=int, nargs='+', default=[0, 2, 3, 8, 9])
elif args.dataset == 'pageblocks':
    from model.model_vec import NetworkPhi
    from dataset.dataset_pageblocks import get_pageblocks_loaders as get_loaders

    parser.add_argument('--positive_label_list', type=int, nargs='+', default=[2, 3, 4, 5])
elif args.dataset == 'grid':
    from model.model_vec import NetworkPhi
    from dataset.dataset_grid import get_grid_loaders as get_loaders

    parser.add_argument('--positive_label_list', type=int, nargs='+', default=[1])
elif args.dataset == 'avila':
    from model.model_vec import NetworkPhi
    from dataset.dataset_avila import get_avila_loaders as get_loaders

    # The avila default label is a string ('A'), so labels are kept as strings.
    parser.add_argument('--positive_label_list', type=str, nargs='+', default=['A'])
else:
    # Not reachable through argparse (``choices`` rejects other values); kept
    # as an explicit error rather than an assert, which -O would strip.
    raise ValueError('unsupported dataset: {}'.format(args.dataset))

# Second pass: strict parse of the full argument set, so genuinely unknown
# options are still rejected.
args = parser.parse_args()


def main(config):
    """Build the data loaders, persist the data split and launch VPU training.

    Args:
        config: parsed command-line namespace. Reads ``gpu``, ``dataset``,
            ``batch_size``, ``num_labeled`` and ``positive_label_list`` here
            and passes the whole namespace on to ``get_checkpoint_path`` and
            ``run_vpu``.

    Side effects:
        Selects the CUDA device when CUDA is available, prints the dataset
        sizes, creates the checkpoint directory if missing and pickles either
        ``idx`` (image datasets) or the tuple of loaders (vector datasets)
        into it.
    """
    image_datasets = ['cifar10', 'fashionMNIST', 'stl10']
    vector_datasets = ['avila', 'pageblocks', 'grid']

    is_cuda = torch.cuda.is_available()
    if is_cuda:
        torch.cuda.set_device(config.gpu)

    loader_kwargs = dict(batch_size=config.batch_size,
                         num_labeled=config.num_labeled,
                         positive_label_list=config.positive_label_list)
    if config.dataset in image_datasets:
        # Image loaders return one extra value, ``idx``, which is saved below.
        x_loader, p_loader, val_x_loader, val_p_loader, test_loader, idx = get_loaders(**loader_kwargs)
    elif config.dataset in vector_datasets:
        x_loader, p_loader, val_x_loader, val_p_loader, test_loader = get_loaders(**loader_kwargs)
    else:
        assert False

    # Note the order: positive loaders come before the unlabeled ones here.
    loaders = (p_loader, x_loader, val_p_loader, val_x_loader, test_loader)

    print('==> Preparing data')
    size_report = (
        ('    number of train data: ', x_loader),
        ('    number of labeled train data: ', p_loader),
        ('    number of test data: ', test_loader),
        ('    number of val x data:', val_x_loader),
        ('    number of val p data:', val_p_loader),
    )
    for label, loader in size_report:
        print(label, len(loader.dataset))

    checkpoint = get_checkpoint_path(config)
    if not os.path.isdir(checkpoint):
        mkdir_p(checkpoint)

    # Persist what is needed to recover the data split of this run.
    dump_spec = None
    if config.dataset in image_datasets:
        dump_spec = ('idx', idx)
    elif config.dataset in vector_datasets:
        dump_spec = ('loaders', loaders)
    if dump_spec is not None:
        dump_name, payload = dump_spec
        with open(os.path.join(checkpoint, dump_name), 'wb') as file:
            pickle.dump(payload, file)

    run_vpu(config, loaders, is_cuda)


def run_vpu(config, loaders, is_cuda):
    """Train the phi network and report the test metrics of the best epoch.

    After every epoch the model is evaluated and a checkpoint is saved. The
    "best" epoch is the one with the lowest validation KL; its test AUC and
    accuracy are printed once training ends.

    Args:
        config: parsed command-line namespace; reads ``learning_rate``,
            ``dataset`` and ``epochs`` and forwards the namespace to
            ``train`` and ``save_checkpoint``.
        loaders: tuple ``(p_loader, x_loader, val_p_loader, val_x_loader,
            test_loader)``.
        is_cuda: whether the model and the batches are moved to the GPU.

    Raises:
        ValueError: if ``config.dataset`` is not one of the known datasets.
    """
    lowest_val_kl = math.inf
    highest_test_acc = -1
    # Metrics of the epoch with the lowest validation KL. They stay None when
    # no epoch ever beats +inf (zero epochs, or a validation KL that is NaN on
    # every epoch), so the final report cannot hit an unbound variable.
    epoch_of_best_val = None
    test_auc_of_best_val = None
    test_acc_of_best_val = None
    (p_loader, x_loader, val_p_loader, val_x_loader, test_loader) = loaders

    lr_phi = config.learning_rate

    if config.dataset in ['cifar10', 'fashionMNIST', 'stl10']:
        model_phi = NetworkPhi()
    elif config.dataset in ['pageblocks', 'grid', 'avila']:
        # Vector models are sized from the length of the first positive sample.
        input_size = len(p_loader.dataset[0][0])
        model_phi = NetworkPhi(input_size=input_size)
    else:
        raise ValueError('unsupported dataset: {}'.format(config.dataset))
    model_phi = model_phi.cuda() if is_cuda else model_phi
    opt_phi = torch.optim.Adam(model_phi.parameters(), lr=lr_phi, betas=(0.5, 0.99))

    for epoch in range(config.epochs):

        # Halve the learning rate at epochs 19, 39, ... by building a fresh
        # Adam optimizer (a new optimizer object starts with empty state).
        if epoch % 20 == 19:
            lr_phi /= 2
            opt_phi = torch.optim.Adam(model_phi.parameters(), lr=lr_phi, betas=(0.5, 0.99))

        phi_loss, kl_loss, reg_loss, phi_p_mean_not_norm, phi_x_mean_not_norm = train(config, model_phi,
                                                                                      opt_phi,
                                                                                      p_loader, x_loader,
                                                                                      is_cuda)

        val_kl, test_acc, test_auc = evaluate(model_phi, x_loader, test_loader, val_p_loader, val_x_loader, epoch,
                                              phi_p_mean_not_norm, phi_x_mean_not_norm, phi_loss, kl_loss, reg_loss,
                                              is_cuda)

        is_val_kl_lowest = val_kl < lowest_val_kl
        is_test_acc_highest = test_acc > highest_test_acc
        lowest_val_kl = min(lowest_val_kl, val_kl)
        highest_test_acc = max(highest_test_acc, test_acc)
        if is_val_kl_lowest:
            test_auc_of_best_val = test_auc
            test_acc_of_best_val = test_acc
            epoch_of_best_val = epoch
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model_phi.state_dict(),
            'optimizer': opt_phi.state_dict(),
        }, is_val_kl_lowest, is_test_acc_highest, config=config)

    if epoch_of_best_val is None:
        print('no epoch improved the validation KL; no early-stopping result to report')
        return

    print('early stopping at {:}th epoch, test AUC : {:.4f}, test acc: {:.4f}'.format(epoch_of_best_val,
                                                                                      test_auc_of_best_val,
                                                                                      test_acc_of_best_val))

def _next_batch(batch_iter, loader):
    """Return ``(inputs, iterator)`` for the next batch of ``loader``.

    ``batch_iter`` is the iterator currently in use, or None if iteration has
    not started yet. A new iterator is created lazily on first use and again
    whenever the current one is exhausted; only ``StopIteration`` triggers the
    restart, any other error propagates. The second element of each batch
    (the label) is discarded.
    """
    if batch_iter is not None:
        try:
            inputs, _ = next(batch_iter)
        except StopIteration:
            pass
        else:
            return inputs, batch_iter
    batch_iter = iter(loader)
    inputs, _ = next(batch_iter)
    return inputs, batch_iter


def train(config, model_phi, opt_phi, p_loader, x_loader, is_cuda):
    """Run ``config.val_iterations`` optimisation steps on the phi network.

    Each step draws one batch from ``x_loader`` and one from ``p_loader``
    (restarting a loader when it runs out) and minimises

        kl_loss + config.lam * reg_mix_log

    where, with column 1 of the model output used as log phi,

      * ``kl_loss`` = logsumexp(log_phi_x) - log(len(x batch)) - mean(log_phi_p)
      * ``reg_mix_log`` is the mean squared difference between the log of a
        mixed target and the model's log phi on the equally mixed input. The
        mixing weight is sampled from Beta(mix_alpha, mix_alpha); targets are
        exp(log_phi_x) for the x batch and 1 for the (shuffled) p batch.

    Args:
        config: namespace providing ``val_iterations``, ``mix_alpha``, ``lam``.
        model_phi: network being trained; put into train mode here.
        opt_phi: optimizer stepping ``model_phi``'s parameters.
        p_loader: loader for the positive batches (labels are ignored).
        x_loader: loader for the x batches (labels are ignored).
        is_cuda: move batches and targets to the GPU when True.

    Returns:
        Tuple of running averages over the steps:
        ``(phi_loss, kl_loss, reg_mix_log, mean phi on p, mean phi on x)``.
    """
    phi_p_avg = AverageMeter()
    phi_x_avg = AverageMeter()
    phi_loss_avg = AverageMeter()
    kl_loss_avg = AverageMeter()
    reg_margin_avg = AverageMeter()
    model_phi.train()

    # Iterators are created on first use inside _next_batch.
    x_iter = None
    p_iter = None

    for _ in range(config.val_iterations):
        data_x, x_iter = _next_batch(x_iter, x_loader)
        data_p, p_iter = _next_batch(p_iter, p_loader)

        if is_cuda:
            data_p, data_x = data_p.cuda(), data_x.cuda()

        # One forward pass over [p ; x]; the first len(data_p) rows belong to p.
        data_all = torch.cat((data_p, data_x))
        output_phi_all = model_phi(data_all)
        log_phi_all = output_phi_all[:, 1]
        num_p = len(data_p)
        log_phi_p = log_phi_all[:num_p]
        log_phi_x = log_phi_all[num_p:]

        kl_loss = torch.logsumexp(log_phi_x, dim=0) - math.log(len(log_phi_x)) - 1 * torch.mean(log_phi_p)

        target_x = log_phi_x.exp()
        target_p = torch.ones(len(data_p), dtype=torch.float32)
        target_p = target_p.cuda() if is_cuda else target_p
        # The element-wise mix below pairs x and p samples one-to-one.
        assert len(data_p) == len(data_x)
        rand_perm = torch.randperm(data_p.size(0))
        data_p_perm, target_p_perm = data_p[rand_perm], target_p[rand_perm]
        m = torch.distributions.beta.Beta(config.mix_alpha, config.mix_alpha)
        lam = m.sample()
        data = lam * data_x + (1 - lam) * data_p_perm
        target = lam * target_x + (1 - lam) * target_p_perm

        # Make sure the mixed batch is on the GPU before the second forward pass.
        if is_cuda:
            data = data.cuda()
            target = target.cuda()
        out_log_phi_all = model_phi(data)
        reg_mix_log = ((torch.log(target) - out_log_phi_all[:, 1]) ** 2).mean()

        phi_loss = kl_loss + config.lam * reg_mix_log
        opt_phi.zero_grad()
        phi_loss.backward()
        opt_phi.step()

        reg_margin_avg.update(reg_mix_log.item())
        phi_loss_avg.update(phi_loss.item())
        kl_loss_avg.update(kl_loss.item())
        phi_p, phi_x = log_phi_p.exp(), log_phi_x.exp()
        phi_p_avg.update(phi_p.mean().item(), len(phi_p))
        phi_x_avg.update(phi_x.mean().item(), len(phi_x))

    return phi_loss_avg.avg, kl_loss_avg.avg, reg_margin_avg.avg, phi_p_avg.avg, phi_x_avg.avg




def evaluate(model_phi, x_loader, test_loader, val_p_loader, val_x_loader, epoch, phi_p_mean_not_norm,
             phi_x_mean_not_norm, phi_loss, kl_loss, reg_loss, is_cuda):
    """Print the training summary of an epoch and evaluate the model.

    Computes the validation KL, scans ``x_loader`` for the largest log phi
    (column 1 of the model output), prints the training phi means divided by
    exp(that maximum), and scores the test set with the same maximum.

    Args:
        model_phi: network to evaluate; put into eval mode here.
        x_loader: loader scanned for the maximum log phi.
        test_loader: loader passed to ``validate`` in 'test' mode.
        val_p_loader, val_x_loader: loaders passed to ``cal_val_kl``.
        epoch: epoch index, used only in the printed summary.
        phi_p_mean_not_norm, phi_x_mean_not_norm: unnormalised mean phi on
            the p / x training batches of this epoch.
        phi_loss, kl_loss, reg_loss: training losses, printed only.
        is_cuda: move batches to the GPU when True.

    Returns:
        ``(val_kl, test_acc, test_auc)``.
    """
    model_phi.eval()
    print('Train Epoch: {}\tphi_loss: {:.4f}\tkl_loss: {:.4f}\treg_loss: {:.4f}'.format(epoch, phi_loss, kl_loss,
                                                                                        reg_loss))
    val_kl = cal_val_kl(model_phi, val_p_loader, val_x_loader, is_cuda)
    print('val kl: {:.4f}'.format(val_kl))

    # Pure inference: no_grad avoids building an autograd graph per batch, and
    # .item() keeps the running maximum a plain float instead of a tensor.
    log_max_phi = -math.inf
    with torch.no_grad():
        for data, _ in x_loader:
            if is_cuda:
                data = data.cuda()
            log_max_phi = max(log_max_phi, model_phi(data)[:, 1].max().item())

    max_phi = math.exp(log_max_phi)
    phi_p_mean, phi_x_mean = phi_p_mean_not_norm / max_phi, phi_x_mean_not_norm / max_phi
    print('Train p data phi_mean: {} \tTrain x data phi_mean: {} \tTrain p-x data phi_mean: {}'.format(
        phi_p_mean, phi_x_mean, phi_p_mean - phi_x_mean))

    test_acc, test_auc = validate(model_phi, test_loader, 'test', log_max_phi, is_cuda)

    return val_kl, test_acc, test_auc


def validate(model_phi, loader, mode, log_max_phi, is_cuda):
    """Score ``model_phi`` on ``loader`` and print accuracy (and AUC for test).

    Column 1 of the model output is used as log phi and shifted by
    ``log_max_phi``; a sample is predicted positive when the shifted value
    exceeds log(0.5). Accuracy is the trace of the confusion matrix divided by
    its sum.

    Args:
        model_phi: network to evaluate; put into eval mode here.
        loader: iterable of ``(data, target)`` batches.
        mode: one of 'train', 'validation', 'test'.
        log_max_phi: value subtracted from every log phi (float or 0-dim
            tensor).
        is_cuda: move batches to the GPU when True.

    Returns:
        ``(accuracy, auc)`` when ``mode == 'test'``; None for the other modes
        (the accuracy is only printed).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    assert mode in ['train', 'validation', 'test']
    model_phi.eval()

    # Collect per-batch results and concatenate once, instead of re-copying
    # the growing tensor on every batch.
    log_phi_batches = []
    target_batches = []
    with torch.no_grad():
        for data, target in loader:
            if is_cuda:
                data, target = data.cuda(), target.cuda()
            log_phi_batches.append(model_phi(data)[:, 1] - log_max_phi)
            target_batches.append(target)
    if not log_phi_batches:
        raise ValueError('validate() received an empty loader')

    log_phi_all = torch.cat(log_phi_batches)
    target_all = torch.cat(target_batches)

    pred_all = (log_phi_all > math.log(0.5)).detach().cpu().numpy()
    log_phi_all = log_phi_all.detach().cpu().numpy()
    target_all = target_all.detach().cpu().numpy()
    conf_mat = confusion_matrix(target_all, pred_all)
    accuracy = np.diag(conf_mat).sum() / conf_mat.sum()
    print(mode + ' accuracy: {:.4f}'.format(accuracy))
    if mode == 'test':
        auc = roc_auc_score(target_all, log_phi_all)
        print('test AUC: {:.4f}'.format(auc))
        print('test confusion matrix:\n', conf_mat, '\n')
        return accuracy, auc
    return None


def _concat_log_phi(model_phi, loader, is_cuda):
    """Return log phi (column 1 of the model output) for every sample in ``loader``.

    Batches are collected in a list and concatenated once. Labels are ignored.
    Must be called with gradients disabled by the caller.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    chunks = []
    for data, _ in loader:
        if is_cuda:
            data = data.cuda()
        chunks.append(model_phi(data)[:, 1])
    if not chunks:
        raise ValueError('cal_val_kl() received an empty loader')
    return torch.cat(chunks)


def cal_val_kl(model_phi, val_p_loader, val_x_loader, is_cuda):
    """Compute the validation KL term of the variational loss.

    The value is

        logsumexp(log_phi_x) - log(number of x samples) - mean(log_phi_p)

    with log phi taken from column 1 of the model output, evaluated over the
    whole of ``val_x_loader`` and ``val_p_loader`` without gradients.

    Args:
        model_phi: network to evaluate; put into eval mode here.
        val_p_loader: loader of positive validation batches.
        val_x_loader: loader of x validation batches.
        is_cuda: move batches to the GPU when True.

    Returns:
        The KL value as a Python float.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model_phi.eval()

    with torch.no_grad():
        # x first, then p: same evaluation order as the loss terms below.
        log_phi_x = _concat_log_phi(model_phi, val_x_loader, is_cuda)
        log_phi_p = _concat_log_phi(model_phi, val_p_loader, is_cuda)
        kl_loss = torch.logsumexp(log_phi_x, dim=0) - math.log(len(log_phi_x)) - torch.mean(log_phi_p)
        return kl_loss.item()


def get_checkpoint_path(config):
    """Return the checkpoint directory for a run configuration.

    The path is rooted at the current working directory and nests, in order,
    the dataset name, the positive label list, the learning rate, ``lam`` and
    ``mix_alpha``. The directory is not created here.
    """
    components = [
        config.dataset,
        f'P={config.positive_label_list}',
        f'lr={config.learning_rate}',
        f'lambda={config.lam}',
        f'alpha={config.mix_alpha}',
    ]
    return os.path.join(os.getcwd(), *components)


def save_checkpoint(state, is_lowest_on_val, is_highest_on_test, config, filename='checkpoint.pth.tar'):
    """Save ``state`` to the run's checkpoint directory and keep "best" copies.

    Args:
        state: object handed to ``torch.save``.
        is_lowest_on_val: when true, also copy the file to
            'model_lowest_on_val.pth.tar'.
        is_highest_on_test: when true, also copy the file to
            'model_highest_on_test.pth.tar'.
        config: run configuration used to derive the checkpoint directory.
        filename: name of the latest-checkpoint file inside that directory.
    """
    ckpt_dir = get_checkpoint_path(config)
    if not os.path.isdir(ckpt_dir):
        mkdir_p(ckpt_dir)

    latest_path = os.path.join(ckpt_dir, filename)
    torch.save(state, latest_path)

    # Each flag that is set gets its own copy of the checkpoint just written.
    snapshots = (
        (is_lowest_on_val, 'model_lowest_on_val.pth.tar'),
        (is_highest_on_test, 'model_highest_on_test.pth.tar'),
    )
    for wanted, snapshot_name in snapshots:
        if wanted:
            shutil.copyfile(latest_path, os.path.join(ckpt_dir, snapshot_name))


# Script entry point: run training with the module-level ``args`` namespace
# produced by the argument parser above.
if __name__ == '__main__':
    main(args)
